iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 32

JAX 好好玩 (32) : 綜合演練 – 預測 MNIST

  • 分享至 

  • xImage
  •  

JAX 官方文件中有一個很好的範例 (Training a Simple Neural Network, with PyTorch Data Loading) ,老頭將其稍稍改寫了一下,並加上一些註解,放在這裏,讀者們可以下載來執行看看。

這個例子以 Pytorch 來載入 MNIST 資料集,並將其轉換為 DeviceArray 格式。JAX 並沒有定義處理 dataset 相關的 API,當然也沒有像 Pytorch 和 TensorFlow 一般包裝好了一些著名的 dataset 供人使用。因此,有關於資料集的處理,JAX 程式設計師必須仰賴既有的 AI 框架所提供的服務。總之,只要這些資料最終能轉為 Numpy 陣列格式, JAX 就可以使用它。

這個例子也介紹了如何使用 vmap 自動向量化的功能。它的重點是:

  • 模型的 predict() 函式,只接受單一筆的訓練資料。
  • 另外用 vmap 轉換 predict,並指定那一個參數的那一個維度來做自動向量化。
  • 接下來,損失函式就要以 vmap 轉換過的函式為依據來計算。
# 定義模型預測函式 : 單一 image 預測
# ==================================================================
def predict(params, image):
    activations = image
    for layer in params[:-1]:
        w = layer['w']
        b = layer['b']
        outputs = jnp.dot(activations,w) + b
        activations = relu(outputs)
  
    final_w, final_b = (params[-1]['w'],params[-1]['b'])
    logits = jnp.dot(activations,final_w) + final_b
    return jnp.exp(logits) / jnp.sum(jnp.exp(logits)) # return softmax
 
# 自動批次的預測函式
# ==================================================================
#  predict(params, image)
#       params: 不做 auto vectorization (對應 in_axes 的 None)
#       image:  對第一維度做 auto vectorization (對應 in_axes 的 0)
 
batched_predict = jax.vmap(predict, in_axes=(None, 0))
# 損失函式
# ==================================================================
#   MSE (Mean Squared Error)
 
def loss(params, images, targets):
    preds = batched_predict(params, images) # 要參考自動向量化的預測函式版本
    return jnp.mean((targets-preds)**2)
 
 
if CFG_UnitTest:
    random_flattened_images = jrand.normal(jrand.PRNGKey(1), (2, 28 * 28))
    random_targets = jnp.array([[1.,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,1.]], dtype=jnp.float32)
    print(loss(HP_Params, random_flattened_images, random_targets))

另外,老頭在這個範例程式中,加入了我慣用的「單元測試 unit test 」的手法。在重要的函式及程式片斷後,用一個 if 控制結構 (參考上面的程式片斷 if CFG_UnitTest: 部份) 來包裝單元測試程式段。在程式開發的過程中,我會將 CFG_UnitTest 設為 True,每寫完一段程式馬上做測試。等到開發完成,再將 CFG_UnitTest 設為 False ,程式執行時就會自動略過單元測試。這個部份也提出來供大家參考。

OK, 大家可以直接去跑老頭提供的 colab 筆記本,Good Luck!


上一篇
JAX 好好玩 (31) : 綜合演練 – 線性迴歸
下一篇
JAX 好好玩 (33) : 類別與 jit (1) : 重新定義 hash
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言